MNIST Interpolations

This notebook aims to reproduce Experiment 2 / Figure 7 in the paper.

Setup of plotting library

In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import tensorflow as tf
from copy import deepcopy
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot

# Set up plotly
init_notebook_mode(connected=True)
layout = go.Layout(
    width=700,
    height=500,
    margin=go.Margin(l=60, r=60, b=40, t=20),
    showlegend=False
)
config={'showLink': False}
colorscale =[[0.0, '#FF881E'], [1.0, '#4E73AE']]

# Make results completely repeatable
seed = 0
np.random.seed(seed)
tf.set_random_seed(seed)
/Users/kilian/dev/tum/2018-mlic-kilian/venv/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters

Create the VAE

following the implementation details in appendix D in the paper.

In [2]:
from tensorflow.python.keras import Sequential, Model
from tensorflow.python.keras.layers import Dense, Input, Lambda
from src.vae import VAE
from src.rbf import RBFLayer

# Implementation details from Appendix D
input_dim = 784   # flattened 28x28 MNIST images
latent_dim = 2    # 2-D latent space so it can be visualized directly
l2_reg = tf.keras.regularizers.l2(1e-5)

# Create the encoder models.
# The first 64-unit layer is a single shared instance, so the mean and
# variance encoders share its weights.
enc_input = Input((input_dim,))
enc_shared = Dense(64, activation='tanh', kernel_regularizer=l2_reg)
enc_mean = Sequential([
    enc_shared,
    Dense(32, activation='tanh', kernel_regularizer=l2_reg),
    Dense(latent_dim, activation='linear', kernel_regularizer=l2_reg)
])
# softplus keeps the predicted variance strictly positive
enc_var = Sequential([
    enc_shared,
    Dense(32, activation='tanh', kernel_regularizer=l2_reg),
    Dense(latent_dim, activation='softplus', kernel_regularizer=l2_reg)
])
enc_mean = Model(enc_input, enc_mean(enc_input))
enc_var = Model(enc_input, enc_var(enc_input))

# Create the decoder models (mirror of the encoder; sigmoid maps the
# output back into the [0, 1] pixel range)
dec_input = Input((latent_dim,))
dec_mean = Sequential([
    Dense(32, activation='tanh', kernel_regularizer=l2_reg),
    Dense(64, activation='tanh', kernel_regularizer=l2_reg),
    Dense(input_dim, activation='sigmoid', kernel_regularizer=l2_reg)
])
dec_mean = Model(dec_input, dec_mean(dec_input))

# Build the RBF network that models the decoder variance.
# `a` is the bandwidth scaling factor used later when computing the
# per-cluster RBF bandwidths.
num_centers = 64
a = 2.0
rbf = RBFLayer([input_dim], num_centers)
dec_var = Model(dec_input, rbf(dec_input))

# Assemble the full VAE; dec_stddev=1. presumably fixes the decoder
# standard deviation during the first training phase — confirm in src.vae
vae = VAE(enc_mean, enc_var, dec_mean, dec_var, dec_stddev=1.)
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.

Filter digits 0 and 1 from MNIST

In [3]:
from tensorflow.python.keras.datasets import mnist

# Train the VAE on the MNIST digits 0 and 1 only
(x_train_all, y_train_all), _ = mnist.load_data()

# Vectorized filter instead of a Python-level loop over all 60k images;
# np.isin preserves the original ordering, so the result is identical.
mask = np.isin(y_train_all, [0, 1])
x_train = x_train_all[mask].astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
y_train = y_train_all[mask]

# Shuffle images and labels with the same random permutation
p = np.random.permutation(len(x_train))
x_train = x_train[p]
y_train = y_train[p]

Train the VAE

without training the generator's variance network. This will be trained separately later.

In [4]:
# Fit the VAE; the variance network is excluded here and trained later.
history = vae.model.fit(
    x_train,
    epochs=50,
    batch_size=32,
    validation_split=0.1,
    verbose=0
)

# Plot train/validation NELBO over epochs
loss_traces = [
    go.Scatter(y=history.history['loss'], name='Train Loss'),
    go.Scatter(y=history.history['val_loss'], name='Validation Loss'),
]
loss_layout = deepcopy(layout)
loss_layout['xaxis'] = {'title': 'Epoch'}
loss_layout['yaxis'] = {'title': 'NELBO'}
loss_layout['showlegend'] = True
iplot(go.Figure(data=loss_traces, layout=loss_layout), config=config)

Visualize the latent space

In [5]:
# Display a 2D plot of the classes in the latent space
sampled, encoded_mean, encoded_var = vae.encoder.predict(x_train)

# Only the first 300 points are shown to keep the plot readable
n_shown = 300
scatter_plot = go.Scatter(
    x=encoded_mean[:n_shown, 0],
    y=encoded_mean[:n_shown, 1],
    mode='markers',
    marker=dict(color=y_train[:n_shown], colorscale=colorscale),
)
iplot(go.Figure(data=[scatter_plot], layout=layout), config=config)

Train the generator's variance network

For this, we first have to find the centers of the latent points.

In [6]:
from sklearn.cluster import KMeans

# Cluster the latent means; the cluster centres become the RBF centres
kmeans_model = KMeans(n_clusters=num_centers, random_state=0).fit(encoded_mean)
centers = kmeans_model.cluster_centers_

# Overlay the centres (red) on the latent scatter plot
center_plot = go.Scatter(
    x=centers[:, 0],
    y=centers[:, 1],
    mode='markers',
    marker=dict(color='red'),
)
iplot(go.Figure(data=[scatter_plot, center_plot], layout=layout), config=config)

Compute the bandwidths

In [7]:
# Group the latent representations by their nearest k-means cluster
clustering = dict((c_i, []) for c_i in range(num_centers))
for z_i, c_i in zip(encoded_mean, kmeans_model.predict(encoded_mean)):
    clustering[c_i].append(z_i)

# Per-cluster bandwidth: 1 / (2 * (a * avg_dist)^2), where avg_dist is the
# mean distance of the cluster's points to its centre.
bandwidths = []
for c_i, cluster in clustering.items():
    if cluster:
        diffs = np.array(cluster) - centers[c_i]
        avg_dist = np.mean(np.linalg.norm(diffs, axis=1))
        # Guard against avg_dist == 0 (all points exactly on the centre),
        # which would otherwise produce an infinite bandwidth; fall back
        # to 0 like the empty-cluster case.
        bandwidth = 0.5 / (a * avg_dist)**2 if avg_dist > 0 else 0
    else:
        # Empty clusters get bandwidth 0 so their RBF unit stays inactive
        bandwidth = 0
    bandwidths.append(bandwidth)
bandwidths = np.array(bandwidths)

Train the variance network

In [8]:
# Train the RBF variance network: initialise its centres and bandwidths
# from the k-means results, then fit with everything else frozen
vae.recompile_for_var_training()
rbf_kernel = rbf.get_weights()[0]
rbf.set_weights([rbf_kernel, centers, bandwidths])

history = vae.model.fit(
    x_train,
    epochs=100,
    batch_size=32,
    validation_split=0.1,
    verbose=0
)

# Plot train/validation NELBO over epochs
var_loss_traces = [
    go.Scatter(y=history.history['loss'], name='Train Loss'),
    go.Scatter(y=history.history['val_loss'], name='Validation Loss'),
]
var_loss_layout = deepcopy(layout)
var_loss_layout['xaxis'] = {'title': 'Epoch'}
var_loss_layout['yaxis'] = {'title': 'NELBO'}
var_loss_layout['showlegend'] = True
iplot(go.Figure(data=var_loss_traces, layout=var_loss_layout), config=config)
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.

Choose two latent points

for finding a geodesic.

In [9]:
# Endpoints for the geodesic: two encoded training points
z_start, z_end = encoded_mean[[5, 26]]

# Mark the two endpoints (red) on top of the latent scatter plot
task_plot = go.Scatter(
    x=[z_start[0], z_end[0]],
    y=[z_start[1], z_end[1]],
    mode='markers',
    marker=dict(color='red'),
)
iplot(go.Figure(data=[scatter_plot, task_plot], layout=layout), config=config)

Plot the magnification factors

In [10]:
from src.util import wrap_model_in_float64

# Get the mean and std predictors from the trained decoder
_, mean_output, var_output = vae.decoder.output
# tf.sqrt turns the predicted variance into a standard deviation
sqrt_layer = Lambda(tf.sqrt)
# NOTE(review): `dec_mean` shadows the Keras model of the same name built
# in the VAE-construction cell; all cells below refer to these wrappers.
dec_mean = Model(vae.decoder.input, mean_output)
dec_std = Model(vae.decoder.input, sqrt_layer(var_output))
# Wrap in float64 — presumably for numerical precision in the geodesic
# computations; confirm in src.util
dec_mean = wrap_model_in_float64(dec_mean)
dec_std = wrap_model_in_float64(dec_std)

session = tf.keras.backend.get_session()
In [11]:
from src.plot import plot_magnification_factor

# Evaluate the magnification factor on a 200x200 grid over [-4, 4]^2,
# overlaying the two geodesic endpoints
grid = np.linspace(-4, 4, 200)
heatmap = plot_magnification_factor(session,
                                    grid,
                                    grid,
                                    dec_mean,
                                    dec_std,
                                    additional_data=[task_plot],
                                    layout=layout,
                                    log_scale=True)
Computing Magnification Factors: 100%|██████████| 40000/40000 [04:23<00:00, 151.92it/s]

Find the geodesic

In [12]:
%%time
from src.discrete import find_geodesic_discrete

# Discretise the curve into 50 nodes and optimise it for up to 400 steps;
# `iterations` holds the intermediate curves for the plot below.
# The printed Length/Energy log comes from inside find_geodesic_discrete.
curve, iterations = find_geodesic_discrete(session, 
                                           z_start, z_end, 
                                           dec_mean, 
                                           std_generator=dec_std,
                                           num_nodes=50,
                                           max_steps=400,
                                           learning_rate=0.01)
print('-' * 20)
Step 0, Length 20.538120, Energy 341.383851, Max velocity ratio 14.399030
Step 20, Length 20.465466, Energy 255.944517, Max velocity ratio 5.714509
Step 40, Length 20.227591, Energy 227.494427, Max velocity ratio 3.521500
Step 60, Length 19.978767, Energy 211.771123, Max velocity ratio 2.316243
Step 80, Length 19.820610, Energy 203.074646, Max velocity ratio 1.683871
Step 100, Length 19.730676, Energy 198.457263, Max velocity ratio 1.668694
Step 120, Length 19.661281, Energy 195.181052, Max velocity ratio 1.414853
Step 140, Length 19.606049, Energy 193.237314, Max velocity ratio 1.295118
Step 160, Length 19.560300, Energy 192.281555, Max velocity ratio 1.473736
Step 180, Length 19.521500, Energy 190.912544, Max velocity ratio 1.201724
Step 200, Length 19.488283, Energy 190.924876, Max velocity ratio 1.522711
Step 220, Length 19.460113, Energy 189.520393, Max velocity ratio 1.132617
Step 240, Length 19.435931, Energy 188.934401, Max velocity ratio 1.065851
Step 260, Length 19.416425, Energy 188.883886, Max velocity ratio 1.384958
Step 280, Length 19.400694, Energy 188.340953, Max velocity ratio 1.197396
Step 300, Length 19.387485, Energy 187.974198, Max velocity ratio 1.069289
Step 320, Length 19.376579, Energy 187.746758, Max velocity ratio 1.051708
Step 340, Length 19.367451, Energy 187.566688, Max velocity ratio 1.049686
Step 360, Length 19.361237, Energy 188.838703, Max velocity ratio 1.621744
Step 380, Length 19.354643, Energy 187.417663, Max velocity ratio 1.180795
Step 400, Length 19.350136, Energy 187.253365, Max velocity ratio 1.076053
--------------------
CPU times: user 56min 56s, sys: 11min 32s, total: 1h 8min 29s
Wall time: 11min 22s
In [13]:
from src.plot import plot_latent_curve_iterations

# Show every 10th optimisation step of the curve on top of the
# magnification-factor heatmap and the latent scatter plot
background = [heatmap, scatter_plot]
plot_latent_curve_iterations(iterations, background, layout, step_size=10)

Appendix

The geodesic ODE solver does not converge without fun_jac

In [14]:
%%time
from src.geodesic import find_geodesic

# Solve the geodesic boundary-value problem without fun_jac; as the output
# below shows, the residual grows and the solver runs out of mesh nodes
# instead of converging.
result, iterations = find_geodesic(session, z_start, z_end, 
                                   dec_mean, std_generator=dec_std, 
                                   initial_nodes=20, max_nodes=1000)
print('-' * 20)
Building Jacobian: 100%|██████████| 4/4 [00:02<00:00,  1.64it/s]
   Iteration    Max residual    Total nodes    Nodes added  
       1          9.90e+00          20             38       
       2          1.09e+02          58             114      
       3          9.97e+01          172            342      
       4          8.07e+02          514          (1026)     
Number of nodes is exceeded after iteration 4, maximum relative residual 8.07e+02.
--------------------
CPU times: user 1d 12h 46min 17s, sys: 50min 11s, total: 1d 13h 36min 29s
Wall time: 4h 57min 3s
In [15]:
# Re-imported on purpose so this cell also runs independently of the
# earlier plotting cell.
from src.plot import plot_latent_curve_iterations

plot_latent_curve_iterations(iterations, [scatter_plot], layout)

Using fun_jac is not feasible

In [ ]:
# NOTE(review): with use_fun_jac=True the Jacobian construction becomes
# prohibitively slow and memory-hungry (see the output below); this cell
# was interrupted manually and is kept only to document the attempt.
result, iterations = find_geodesic(session, z_start, z_end, 
                                   dec_mean, std_generator=dec_std, 
                                   initial_nodes=20, max_nodes=1000,
                                   use_fun_jac=True)
Building Jacobian: 100%|██████████| 784/784 [09:56<00:00,  1.31it/s]
Building Jacobian: 100%|██████████| 784/784 [14:37<00:00,  1.12s/it]
Building Jacobian: 100%|██████████| 4/4 [1:05:32<00:00, 983.07s/it]
Building Jacobian:   0%|          | 0/4 [00:00<?, ?it/s]

I had to kill the cell above after it had been running for a couple of hours and had consumed 38 GB of RAM.